{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "initial_id",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-01-24T05:12:02.572413Z",
     "start_time": "2025-01-24T05:12:02.565568Z"
    },
    "collapsed": true,
    "execution": {
     "iopub.execute_input": "2025-01-24T12:01:14.195219Z",
     "iopub.status.busy": "2025-01-24T12:01:14.195068Z",
     "iopub.status.idle": "2025-01-24T12:01:19.473853Z",
     "shell.execute_reply": "2025-01-24T12:01:19.473221Z",
     "shell.execute_reply.started": "2025-01-24T12:01:14.195199Z"
    },
    "jupyter": {
     "outputs_hidden": true
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sys.version_info(major=3, minor=10, micro=14, releaselevel='final', serial=0)\n",
      "matplotlib 3.10.0\n",
      "numpy 1.26.4\n",
      "pandas 2.2.3\n",
      "sklearn 1.6.0\n",
      "torch 2.5.1+cu124\n",
      "cuda:0\n"
     ]
    }
   ],
   "source": [
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "import numpy as np\n",
    "import sklearn\n",
    "import pandas as pd\n",
    "import os\n",
    "import sys\n",
    "import time\n",
    "from tqdm.auto import tqdm\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "print(sys.version_info)\n",
    "for module in mpl, np, pd, sklearn, torch:\n",
    "    print(module.__name__, module.__version__)\n",
    "\n",
    "device = torch.device(\"cuda:0\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "print(device)\n",
    "\n",
    "seed = 42"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a3c3b37d6bda167b",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-01-24T05:12:02.648919Z",
     "start_time": "2025-01-24T05:12:02.643102Z"
    },
    "execution": {
     "iopub.execute_input": "2025-01-24T12:01:19.475323Z",
     "iopub.status.busy": "2025-01-24T12:01:19.474975Z",
     "iopub.status.idle": "2025-01-24T12:01:19.481768Z",
     "shell.execute_reply": "2025-01-24T12:01:19.481271Z",
     "shell.execute_reply.started": "2025-01-24T12:01:19.475301Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "length 1115394\n",
      "First Citizen:\n",
      "Before we proceed any further, hear me speak.\n",
      "\n",
      "All:\n",
      "Speak, speak.\n",
      "\n",
      "First Citizen:\n",
      "You\n"
     ]
    }
   ],
   "source": [
    "with open(\"./shakespeare.txt\", \"r\", encoding=\"utf8\") as file:\n",
    "    text = file.read()\n",
    "\n",
    "print(\"length\", len(text))\n",
    "print(text[0:100])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8e8a6265b1b95715",
   "metadata": {},
   "source": [
    "## 构造字典"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ed24e634656c7dfb",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-01-24T05:12:02.680297Z",
     "start_time": "2025-01-24T05:12:02.667925Z"
    },
    "execution": {
     "iopub.execute_input": "2025-01-24T12:01:19.482422Z",
     "iopub.status.busy": "2025-01-24T12:01:19.482257Z",
     "iopub.status.idle": "2025-01-24T12:01:19.503819Z",
     "shell.execute_reply": "2025-01-24T12:01:19.503343Z",
     "shell.execute_reply.started": "2025-01-24T12:01:19.482403Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "65\n",
      "['\\n', ' ', '!', '$', '&', \"'\", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']\n"
     ]
    }
   ],
   "source": [
    "# 去重，留下独立字符，并排序（）\n",
    "vocab = sorted(set(text))\n",
    "print(len(vocab))\n",
    "print(vocab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d4ebd525b8209c99",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-01-24T05:12:02.702715Z",
     "start_time": "2025-01-24T05:12:02.698303Z"
    },
    "execution": {
     "iopub.execute_input": "2025-01-24T12:01:19.504526Z",
     "iopub.status.busy": "2025-01-24T12:01:19.504365Z",
     "iopub.status.idle": "2025-01-24T12:01:19.507484Z",
     "shell.execute_reply": "2025-01-24T12:01:19.507024Z",
     "shell.execute_reply.started": "2025-01-24T12:01:19.504508Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'\\n': 0, ' ': 1, '!': 2, '$': 3, '&': 4, \"'\": 5, ',': 6, '-': 7, '.': 8, '3': 9, ':': 10, ';': 11, '?': 12, 'A': 13, 'B': 14, 'C': 15, 'D': 16, 'E': 17, 'F': 18, 'G': 19, 'H': 20, 'I': 21, 'J': 22, 'K': 23, 'L': 24, 'M': 25, 'N': 26, 'O': 27, 'P': 28, 'Q': 29, 'R': 30, 'S': 31, 'T': 32, 'U': 33, 'V': 34, 'W': 35, 'X': 36, 'Y': 37, 'Z': 38, 'a': 39, 'b': 40, 'c': 41, 'd': 42, 'e': 43, 'f': 44, 'g': 45, 'h': 46, 'i': 47, 'j': 48, 'k': 49, 'l': 50, 'm': 51, 'n': 52, 'o': 53, 'p': 54, 'q': 55, 'r': 56, 's': 57, 't': 58, 'u': 59, 'v': 60, 'w': 61, 'x': 62, 'y': 63, 'z': 64}\n"
     ]
    }
   ],
   "source": [
    "#每个字符都编好号，enumerate对每一个位置编号，生成的是列表中是元组，下面字典生成式\n",
    "char2idx = {char: idx for idx, char in enumerate(vocab)}\n",
    "print(char2idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "b9f0c2607c3d284a",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-01-24T05:12:02.720411Z",
     "start_time": "2025-01-24T05:12:02.715723Z"
    },
    "execution": {
     "iopub.execute_input": "2025-01-24T12:01:19.508404Z",
     "iopub.status.busy": "2025-01-24T12:01:19.508000Z",
     "iopub.status.idle": "2025-01-24T12:01:19.511235Z",
     "shell.execute_reply": "2025-01-24T12:01:19.510759Z",
     "shell.execute_reply.started": "2025-01-24T12:01:19.508384Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['\\n' ' ' '!' '$' '&' \"'\" ',' '-' '.' '3' ':' ';' '?' 'A' 'B' 'C' 'D' 'E'\n",
      " 'F' 'G' 'H' 'I' 'J' 'K' 'L' 'M' 'N' 'O' 'P' 'Q' 'R' 'S' 'T' 'U' 'V' 'W'\n",
      " 'X' 'Y' 'Z' 'a' 'b' 'c' 'd' 'e' 'f' 'g' 'h' 'i' 'j' 'k' 'l' 'm' 'n' 'o'\n",
      " 'p' 'q' 'r' 's' 't' 'u' 'v' 'w' 'x' 'y' 'z']\n"
     ]
    }
   ],
   "source": [
    "# 把vocab从列表变为ndarray\n",
    "idx2char = np.array(vocab)\n",
    "print(idx2char)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "493c82cc4b1ab425",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-01-24T05:12:02.829284Z",
     "start_time": "2025-01-24T05:12:02.721415Z"
    },
    "execution": {
     "iopub.execute_input": "2025-01-24T12:01:19.513956Z",
     "iopub.status.busy": "2025-01-24T12:01:19.513501Z",
     "iopub.status.idle": "2025-01-24T12:01:19.631499Z",
     "shell.execute_reply": "2025-01-24T12:01:19.630870Z",
     "shell.execute_reply.started": "2025-01-24T12:01:19.513925Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1115394,)\n",
      "1115394\n",
      "[18 47 56 57 58  1 15 47 58 47]\n",
      "First Citi\n"
     ]
    }
   ],
   "source": [
    "# 把字符都转换为id\n",
    "text_as_int = np.array([char2idx[c] for c in text])\n",
    "print(text_as_int.shape)\n",
    "print(len(text_as_int))\n",
    "print(text_as_int[0:10])\n",
    "print(text[0:10])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "12e5e79707eddd0d",
   "metadata": {},
   "source": [
    "## 把莎士比亚文集分成一个一个的样本"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "f1afc949c019504a",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-01-24T05:12:02.837714Z",
     "start_time": "2025-01-24T05:12:02.830290Z"
    },
    "execution": {
     "iopub.execute_input": "2025-01-24T12:01:19.632715Z",
     "iopub.status.busy": "2025-01-24T12:01:19.632218Z",
     "iopub.status.idle": "2025-01-24T12:01:19.639541Z",
     "shell.execute_reply": "2025-01-24T12:01:19.638964Z",
     "shell.execute_reply.started": "2025-01-24T12:01:19.632683Z"
    }
   },
   "outputs": [],
   "source": [
    "from torch.utils.data import Dataset, DataLoader\n",
    "\n",
    "\n",
    "class CharDataset(Dataset):\n",
    "    def __init__(self, text_as_int, seq_length):\n",
    "        # text_as_int是字符的id列表，seq_length是每个样本的长度\n",
    "        self.sub_len = seq_length + 1  # 因为要预测下一个字符，所以要多预留一个位置\n",
    "        self.text_as_int = text_as_int\n",
    "        self.num_seq = len(text_as_int) // self.sub_len  # 样本的数量\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        # index是样本的索引，返回的是一个样本，比如第一个，就是0-100的字符,总计101个字符\n",
    "        # 返回一个长度为 self.sub_len 的序列片段。\n",
    "        return self.text_as_int[index * self.sub_len:(index + 1) * self.sub_len]\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.num_seq\n",
    "\n",
    "\n",
    "def collate_fct(batch):\n",
    "    # batch是一个列表,列表中的每一个元素是一个样本，有101个字符，前100个是输入，后100个是输出\n",
    "    src_list = []  # 输入\n",
    "    trg_list = []  # 输出\n",
    "    for part in batch:\n",
    "        src_list.append(part[:-1])  # 输入序列\n",
    "        trg_list.append(part[1:])  # 输出序列\n",
    "\n",
    "    src_ndarray = np.array(src_list)  # 把列表转换为ndarray\n",
    "    trg_ndarray = np.array(trg_list)  # 把列表转换为ndarray\n",
    "    # 返回的是一个元组，元组中的每一个元素是一个torch.Tensor\n",
    "    return torch.Tensor(src_ndarray).to(dtype=torch.int64), torch.Tensor(trg_ndarray).to(dtype=torch.int64)\n",
    "\n",
    "\n",
    "# 每个样本的长度是101，也就是100个字符+1个结束符\n",
    "train_ds = CharDataset(text_as_int, seq_length=100)\n",
    "# shuffle=False表示不打乱顺序\n",
    "train_loader = DataLoader(train_ds, batch_size=64,\n",
    "                          collate_fn=collate_fct, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "73efe8a799e66576",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-01-24T05:12:02.844993Z",
     "start_time": "2025-01-24T05:12:02.838720Z"
    },
    "execution": {
     "iopub.execute_input": "2025-01-24T12:01:19.640698Z",
     "iopub.status.busy": "2025-01-24T12:01:19.640302Z",
     "iopub.status.idle": "2025-01-24T12:01:19.682485Z",
     "shell.execute_reply": "2025-01-24T12:01:19.681926Z",
     "shell.execute_reply.started": "2025-01-24T12:01:19.640669Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([64, 100])\n",
      "torch.Size([64, 100])\n"
     ]
    }
   ],
   "source": [
    "for datas, labels in train_loader:\n",
    "    print(datas.shape)\n",
    "    print(labels.shape)\n",
    "    break"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2a771010f6cc72a4",
   "metadata": {},
   "source": [
    "## 定义模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "f5c94a2a0ba22d84",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-01-24T05:12:02.852945Z",
     "start_time": "2025-01-24T05:12:02.847006Z"
    },
    "execution": {
     "iopub.execute_input": "2025-01-24T12:01:19.683527Z",
     "iopub.status.busy": "2025-01-24T12:01:19.683126Z",
     "iopub.status.idle": "2025-01-24T12:01:19.688379Z",
     "shell.execute_reply": "2025-01-24T12:01:19.687931Z",
     "shell.execute_reply.started": "2025-01-24T12:01:19.683497Z"
    }
   },
   "outputs": [],
   "source": [
    "class CharLSTM(nn.Module):\n",
    "    def __init__(self, vocab_size, embedding_dim=256, hidden_dim=1024):\n",
    "        super().__init__()\n",
    "        # 词嵌入层\n",
    "        self.embedding = nn.Embedding(vocab_size, embedding_dim)\n",
    "        # 循环神经网络层\n",
    "        # num_layers: 隐藏层的层数\n",
    "        # batch_first: 输入输出的格式，默认是(seq,batch,feature)，设置为True后为(batch,seq,feature)\n",
    "        # bidirectional: 是否使用双向循环神经网络\n",
    "        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)\n",
    "        # 输出层\n",
    "        self.fc = nn.Linear(hidden_dim, vocab_size)\n",
    "\n",
    "    def forward(self, x, hidden=None):\n",
    "        # [bacth_size,seq_len,vocab_size]->[batch_size,seq_len,embedding_dim]\n",
    "        # [batch_size, 100,65]->[batch_size,100,256]\n",
    "        x = self.embedding(x)  # 词嵌入\n",
    "        # [batch_size,seq_len,embedding_dim]->[batch_size,seq_len,hidden_dim]\n",
    "        # [batch_size,100,256]->[batch_size, 100,1024]\n",
    "        output, hidden = self.lstm(x, hidden)  # 循环神经网络\n",
    "        #这里和02的差异是没有只拿最后一个输出，而是把所有的输出都拿出来了\n",
    "        # [batch_size,seq_len,hidden_dim]->[batch_size,seq_len,vocab_size]\n",
    "        # [batch_size,100,1024]->[batch_size,100,65]\n",
    "        x = self.fc(output)\n",
    "        return x, hidden  # x的shape是(batch_size, seq_len, vocab_size)\n",
    "\n",
    "\n",
    "vocab_size = len(vocab)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "55f024c0091d7cc",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-01-24T05:12:02.888466Z",
     "start_time": "2025-01-24T05:12:02.853950Z"
    },
    "execution": {
     "iopub.execute_input": "2025-01-24T12:01:19.689287Z",
     "iopub.status.busy": "2025-01-24T12:01:19.688882Z",
     "iopub.status.idle": "2025-01-24T12:01:19.733765Z",
     "shell.execute_reply": "2025-01-24T12:01:19.733327Z",
     "shell.execute_reply.started": "2025-01-24T12:01:19.689267Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " 单层单向 RNN \n",
      "            embedding.weight            paramerters num: 16640\n",
      "           lstm.weight_ih_l0            paramerters num: 1048576\n",
      "           lstm.weight_hh_l0            paramerters num: 4194304\n",
      "            lstm.bias_ih_l0             paramerters num: 4096\n",
      "            lstm.bias_hh_l0             paramerters num: 4096\n",
      "               fc.weight                paramerters num: 66560\n",
      "                fc.bias                 paramerters num: 65\n"
     ]
    }
   ],
   "source": [
    "print(\" 单层单向 RNN \")\n",
    "for key, value in CharLSTM(vocab_size).named_parameters():\n",
    "    print(f\"{key:^40}paramerters num: {np.prod(value.shape)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4ab858cede55e136",
   "metadata": {},
   "source": [
    "## 训练模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "9774a1ebaacafd67",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-01-24T05:12:02.895436Z",
     "start_time": "2025-01-24T05:12:02.889479Z"
    },
    "execution": {
     "iopub.execute_input": "2025-01-24T12:01:19.734514Z",
     "iopub.status.busy": "2025-01-24T12:01:19.734289Z",
     "iopub.status.idle": "2025-01-24T12:01:19.739732Z",
     "shell.execute_reply": "2025-01-24T12:01:19.739184Z",
     "shell.execute_reply.started": "2025-01-24T12:01:19.734496Z"
    }
   },
   "outputs": [],
   "source": [
    "class SaveCheckpointsCallback:\n",
    "    def __init__(self, save_dir, save_step=5000, save_best_only=True):\n",
    "        self.save_dir = save_dir  # 保存路径\n",
    "        self.save_step = save_step  # 保存步数\n",
    "        self.save_best_only = save_best_only  # 是否只保存最好的模型\n",
    "        self.best_metric = -1  # 最好的指标，指标不可能为负数，所以初始化为-1\n",
    "        # 创建保存路径\n",
    "        if not os.path.exists(self.save_dir):  # 如果不存在保存路径，则创建\n",
    "            os.makedirs(self.save_dir)\n",
    "\n",
    "    # 对象被调用时：当你将对象像函数一样调用时，Python 会自动调用 __call__ 方法。\n",
    "    # state_dict() 返回模型参数的字典，包括模型参数和优化器参数\n",
    "    # metric 是指标，可以是验证集的准确率，也可以是其他指标\n",
    "    def __call__(self, step, state_dict, metric=None):\n",
    "        if step % self.save_step > 0:\n",
    "            return  # 不是保存步数，则直接返回\n",
    "\n",
    "        if self.save_best_only:\n",
    "            assert metric is not None  # 必须传入metric\n",
    "            if metric >= self.best_metric:  # 如果当前指标大于最好的指标\n",
    "                # save checkpoint\n",
    "                # 保存最好的模型，覆盖之前的模型，不保存step，只保存state_dict，即模型参数，不保存优化器参数\n",
    "                torch.save(state_dict, os.path.join(self.save_dir, \"05_text_generation.ckpt\"))\n",
    "                self.best_metric = metric  # 更新最好的指标\n",
    "        else:\n",
    "            # 保存模型\n",
    "            torch.save(state_dict, os.path.join(self.save_dir, f\"{step}.ckpt\"))\n",
    "            # 保存每个step的模型，不覆盖之前的模型，保存step，保存state_dict，即模型参数，不保存优化器参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "16fb5bafe219f6d1",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-01-24T05:12:02.902817Z",
     "start_time": "2025-01-24T05:12:02.896449Z"
    },
    "execution": {
     "iopub.execute_input": "2025-01-24T12:01:19.740938Z",
     "iopub.status.busy": "2025-01-24T12:01:19.740540Z",
     "iopub.status.idle": "2025-01-24T12:01:19.747824Z",
     "shell.execute_reply": "2025-01-24T12:01:19.747345Z",
     "shell.execute_reply.started": "2025-01-24T12:01:19.740909Z"
    }
   },
   "outputs": [],
   "source": [
    "def training(model,\n",
    "             train_loader,\n",
    "             epoch,\n",
    "             loss_fct,\n",
    "             optimizer,\n",
    "             save_ckpt_callback=None,\n",
    "             stateful=False  # 想用stateful，batch里的数据就必须连续，不能打乱\n",
    "             ):\n",
    "    record_dict = {\n",
    "        \"train\": [],\n",
    "    }\n",
    "\n",
    "    global_step = 0  # 全局步数\n",
    "    model.train()  # 训练模式\n",
    "    hidden = None  # 隐藏层状态\n",
    "    with tqdm(total=epoch * len(train_loader)) as pbar:\n",
    "        for epoch_id in range(epoch):\n",
    "            for datas, labels in train_loader:\n",
    "                datas = datas.to(device)\n",
    "                labels = labels.to(device)\n",
    "\n",
    "                # 模型前向计算,如果数据集打乱了，stateful=False，hidden就要清空\n",
    "                # 如果数据集没有打乱，stateful=True，hidden就不需要清空\n",
    "                logits, hidden = model(datas, hidden=hidden if stateful else None)\n",
    "                # 计算损失,交叉熵损失第一个参数要是二阶张量，第二个参数要是一阶张量，所以要reshape\n",
    "                loss = loss_fct(logits.reshape(-1, vocab_size), labels.reshape(-1))\n",
    "\n",
    "                # 反向传播\n",
    "                optimizer.zero_grad()  # 梯度清零\n",
    "                loss.backward()  # 反向传播\n",
    "                optimizer.step()  # 优化器更新参数\n",
    "\n",
    "                loss = loss.cpu().item()\n",
    "\n",
    "                record_dict[\"train\"].append({\n",
    "                    \"loss\": loss,\n",
    "                    \"step\": global_step\n",
    "                })\n",
    "\n",
    "                if save_ckpt_callback is not None:\n",
    "                    # model.state_dict() 返回模型参数的字典，包括模型参数和优化器参数\n",
    "                    save_ckpt_callback(global_step, model.state_dict(), metric=-loss)\n",
    "                    # 保存最好的模型，覆盖之前的模型，保存step，保存state_dict,通过metric判断是否保存最好的模型\n",
    "\n",
    "                # 更新进度条和全局步数\n",
    "                pbar.update(1)  # 更新进度条\n",
    "                global_step += 1  # 全局步数加一\n",
    "                pbar.set_postfix({\"epoch\": epoch_id})\n",
    "\n",
    "    return record_dict  # 训练结束，返回记录字典 record_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "80c65730aa391cf6",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-01-24T05:12:40.763785Z",
     "start_time": "2025-01-24T05:12:02.903821Z"
    },
    "execution": {
     "iopub.execute_input": "2025-01-24T12:01:19.748944Z",
     "iopub.status.busy": "2025-01-24T12:01:19.748474Z",
     "iopub.status.idle": "2025-01-24T12:06:25.781297Z",
     "shell.execute_reply": "2025-01-24T12:06:25.780815Z",
     "shell.execute_reply.started": "2025-01-24T12:01:19.748914Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 17300/17300 [05:04<00:00, 56.88it/s, epoch=99]\n"
     ]
    }
   ],
   "source": [
    "epoch = 100\n",
    "\n",
    "# 单层单边\n",
    "model = CharLSTM(vocab_size=vocab_size)\n",
    "model.to(device)\n",
    "\n",
    "# 1. 定义损失函数 采用交叉熵损失\n",
    "loss_fct = nn.CrossEntropyLoss()\n",
    "\n",
    "# 2. 定义优化器\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
    "\n",
    "# 3.save model checkpoint\n",
    "if not os.path.exists(\"checkpoints\"):\n",
    "    os.makedirs(\"checkpoints\")\n",
    "save_ckpt_callback = SaveCheckpointsCallback(save_dir=\"checkpoints\", save_step=len(train_loader), save_best_only=True)\n",
    "\n",
    "# 训练过程\n",
    "record_dict = training(\n",
    "    model,\n",
    "    train_loader,\n",
    "    epoch,\n",
    "    loss_fct,\n",
    "    optimizer,\n",
    "    save_ckpt_callback=save_ckpt_callback,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "6aa3eaf8a1631029",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-24T12:06:25.782228Z",
     "iopub.status.busy": "2025-01-24T12:06:25.781877Z",
     "iopub.status.idle": "2025-01-24T12:06:25.909575Z",
     "shell.execute_reply": "2025-01-24T12:06:25.909136Z",
     "shell.execute_reply.started": "2025-01-24T12:06:25.782206Z"
    }
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAikAAAGdCAYAAADXIOPgAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAWw5JREFUeJzt3Xl4U1X+BvA3SdN0oelC6d5CoezQQtksIKLs4IILKi64oI4OjDr4QwedUdAZcdwYVERccRlEUUBH2cpSFluWQguUpXSjLaULpW3SvWlyfn8kuW1oi3aB3JL38zw8ITf33px8uTRvzz3nXoUQQoCIiIhIZpT2bgARERFRcxhSiIiISJYYUoiIiEiWGFKIiIhIlhhSiIiISJYYUoiIiEiWGFKIiIhIlhhSiIiISJac7N2AP8JkMuH8+fPw8PCAQqGwd3OIiIjoDxBCoLy8HEFBQVAqW98v0ilCyvnz5xEaGmrvZhAREVEb5ObmIiQkpNXbdYqQ4uHhAcD8IbVabYft12AwYNu2bZg8eTLUanWH7bezYR1YA4A1sGIdWAOANbBqbx30ej1CQ0Ol7/HW6hQhxXqKR6vVdnhIcXNzg1ardfiD0NHrwBqwBlasA2sAsAZWHVWHtg7V4MBZIiIikiWGFCIiIpIlhhQiIiKSJYYUIiIikiWGFCIiIpIlhhQiIiKSJYYUIiIikiWGFCIiIpIlhhQiIiKSJYYUIiIikiWGFCIiIpIlhhQiIiKSJYcOKV/EZ+PHLCVSC8rt3RQiIiK6hEOHlM0pBdhToERuabW9m0JERESXcOiQYr11tEkIO7eEiIiILuXQIUVpzihgRiEiIpIfhw4p7EkhIiKSL4cOKexJISIiki8HDynsSSEiIpIrhw4plowCEzMKERGR7Dh0SLH2pAj2pBAREcmOg4cU8yN7UoiIiOTHoUMKZ/cQERHJl0OHFPakEBERyVe7Qsobb7wBhUKBZ5999rLrrVu3Dv369YOLiwsGDx6MTZs2tedtOwzHpBAREclXm0PKoUOHsGrVKkRGRl52vfj4eMyePRtz585FUlISZs6ciZkzZyIlJaWtb91hGqYg27khRERE1ESbQkpFRQXuv/9+fPLJJ/D29r7susuXL8fUqVOxcOFC9O/fH6+99hqio6PxwQcftKnBHalhCjJTChERkdw4tWWjefPmYcaMGZg4cSL++c9/XnbdhIQELFiwwGbZlClTsHHjxha3qa2tRW1trfRcr9cDAAwGAwwGQ1ua3DxLOKmvN3bsfjsZ62dnDViDxo+OinVgDQDWwKq9dWhv/VodUtauXYsjR47g0KFDf2j9goIC+Pv72yzz9/dHQUFBi9ssXboUS5YsabJ827ZtcHNza12DL+NCkRKAEidPncKmspMdtt/OKjY21t5NsDvWgDWwYh1YA4A1sGprHaqqqtr1vq0KKbm5uXjmmWcQGxsLFxeXdr3x5SxatMim90Wv1yM0NBSTJ0+GVqvtsPfZok8GLhahT99+mD42vMP229kYDAbExsZi0qRJUKvV9m6OXbAGrIEV68AaAKyBVXvrYD0T0latCimHDx9GUVERoqOjpWVGoxF79uzBBx98gNraWqhUKpttAgICUFhYaLOssLAQAQEBLb6PRqOBRqNpslytVnfowaJSmofkKJRKhz4IrTq6vp0Ra8AaWLEOrAHAGli1tQ7trV2rBs5OmDABx48fR3JysvRn+PDhuP/++5GcnNwkoABATEwMduzYYbMsNjYWMTEx7Wp4R+AUZCIiIvlqVU+Kh4cHBg0aZLPM3d0dXbt2lZbPmTMHwcHBWLp0KQDgmWeewQ033IB33nkHM2bMwNq1a5GYmIiPP/64gz5C2/FibkRERPLV4VeczcnJQX5+vvR89OjRWLNmDT7++GNERUXhhx9+wMaNG5uEHXtQWFKKAFMKERGR3LRpCnJjcXFxl30OALNmzcKsWbPa+1YdTupJMdm3HURE
RNSUg9+7h2NSiIiI5MrBQ4r5kWNSiIiI5MehQ4pCuncPUwoREZHcOHRIsfakMKMQERHJj4OHFPakEBERyZVDh5SG0z12bggRERE14dAhpeF0D1MKERGR3Dh4SGFPChERkVw5eEgxP3JMChERkfw4dEjhFGQiIiL5cuyQYnlkRiEiIpIfhw4pnIJMREQkXw4eUsyPHDhLREQkPw4dUhS8wSAREZFsOXRIYU8KERGRfDl4SLH0pIAphYiISG4cOqQo2JNCREQkWw4dUpQck0JERCRbDh5SzI/sSSEiIpIfhw4pvOIsERGRfDl0SFFaPj17UoiIiOTHsUOKdUwKUwoREZHsMKSAPSlERERy5NAhpWEKMlMKERGR3Dh0SGmYgmznhhAREVETDh5SzI/sSSEiIpIfhw4plozCkEJERCRDjh1SOHCWiIhIthw6pPCy+ERERPLl4CHF/MieFCIiIvlx6JDCy+ITERHJl0OHFGtPCjMKERGR/LQqpKxcuRKRkZHQarXQarWIiYnB5s2bW1x/9erVUCgUNn9cXFza3eiOomRPChERkWw5tWblkJAQvPHGG+jduzeEEPjyyy9x2223ISkpCQMHDmx2G61Wi9TUVOm59RSLHEg9KfZtBhERETWjVSHllltusXn+r3/9CytXrsT+/ftbDCkKhQIBAQFtb+EVxDEpRERE8tWqkNKY0WjEunXrUFlZiZiYmBbXq6ioQPfu3WEymRAdHY3XX3+9xUBjVVtbi9raWum5Xq8HABgMBhgMhrY2uQlhMgIAjEZTh+63s7F+dtaANWj86KhYB9YAYA2s2luH9tZPIVp5kZDjx48jJiYGNTU16NKlC9asWYPp06c3u25CQgLS0tIQGRkJnU6Ht99+G3v27MGJEycQEhLS4nssXrwYS5YsabJ8zZo1cHNza01zLyupWIHVaSpEaE34y0BTh+2XiIiIgKqqKtx3333Q6XTQarWt3r7VIaWurg45OTnQ6XT44Ycf8Omnn2L37t0YMGDA725rMBjQv39/zJ49G6+99lqL6zXXkxIaGori4uI2fciW/HI0D3/94QSGh3ni28dHddh+OxuDwYDY2FhMmjQJarXa3s2xC9aANbBiHVgDgDWwam8d9Ho9fH192xxSWn26x9nZGREREQCAYcOG4dChQ1i+fDlWrVr1u9uq1WoMHToU6enpl11Po9FAo9E0u31HHixqJ8vHVygc+iC06uj6dkasAWtgxTqwBgBrYNXWOrS3du2+TorJZLLp9bgco9GI48ePIzAwsL1v2yGUvHcPERGRbLWqJ2XRokWYNm0awsLCUF5ejjVr1iAuLg5bt24FAMyZMwfBwcFYunQpAODVV1/Fddddh4iICJSVleGtt95CdnY2HnvssY7/JG3QcFl8phQiIiK5aVVIKSoqwpw5c5Cfnw9PT09ERkZi69atmDRpEgAgJycHSmVD50xpaSkef/xxFBQUwNvbG8OGDUN8fPwfGr9yNSiUnIJMREQkV60KKZ999tllX4+Li7N5vmzZMixbtqzVjbparJeVY0YhIiKSH967B+xJISIikiMHDymW0z28RAoREZHsOHRIsV4Wv5WXiiEiIqKrwKFDSsPpHvu2g4iIiJpy8JDC2T1ERERy5dAhRcGeFCIiItly6JCi5JgUIiIi2XLwkGJ+ZEQhIiKSHwcPKRyTQkREJFcOHVI4JoWIiEi+HDqkcEwKERGRfDGkgD0pREREcuTQIUXBe/cQERHJlkOHlIbTPXZuCBERETXh0CGFPSlERETy5dAhRcmQQkREJFsOHVIUPN1DREQkWw4dUngxNyIiIvly8JBifuQUZCIiIvlx8JDCnhQiIiK5cuiQYp3dw4xCREQkPw4dUtiTQkREJF8OHlLMjxyTQkREJD8OHVIUvMEgERGRbDl0SGFPChERkXw5eEhhTwoREZFcOXhIMT+yJ4WIiEh+HDqkWMekAOxNISIikhuHDinKRiGFvSlERETy4uAhpeHv
vFYKERGRvDh0SFEwpBAREcmWg4eUxmNS7NgQIiIiaqJVIWXlypWIjIyEVquFVqtFTEwMNm/efNlt1q1bh379+sHFxQWDBw/Gpk2b2tXgjsTTPURERPLVqpASEhKCN954A4cPH0ZiYiJuuukm3HbbbThx4kSz68fHx2P27NmYO3cukpKSMHPmTMycORMpKSkd0vj24sBZIiIi+WpVSLnlllswffp09O7dG3369MG//vUvdOnSBfv37292/eXLl2Pq1KlYuHAh+vfvj9deew3R0dH44IMPOqTx7aWwCSlMKURERHLi1NYNjUYj1q1bh8rKSsTExDS7TkJCAhYsWGCzbMqUKdi4ceNl911bW4va2lrpuV6vBwAYDAYYDIa2NrkJY33DvurqDDCoOmzXnYq1ph1Z286GNWANrFgH1gBgDazaW4f21q/VIeX48eOIiYlBTU0NunTpgg0bNmDAgAHNrltQUAB/f3+bZf7+/igoKLjseyxduhRLlixpsnzbtm1wc3NrbZNbZD7F42TZdyzc1R22604pNjbW3k2wO9aANbBiHVgDgDWwamsdqqqq2vW+rQ4pffv2RXJyMnQ6HX744Qc89NBD2L17d4tBpS0WLVpk0wOj1+sRGhqKyZMnQ6vVdtj71NXVAfvjAAA3TZyIru7OHbbvzsRgMCA2NhaTJk2CWu2YSY01YA2sWAfWAGANrNpbB+uZkLZqdUhxdnZGREQEAGDYsGE4dOgQli9fjlWrVjVZNyAgAIWFhTbLCgsLERAQcNn30Gg00Gg0TZar1eordrCoVE4OfSACV7a+nQVrwBpYsQ6sAcAaWLW1Du2tXbuvk2IymWzGjzQWExODHTt22CyLjY1tcQyLPShgHjDLe/cQERHJS6t6UhYtWoRp06YhLCwM5eXlWLNmDeLi4rB161YAwJw5cxAcHIylS5cCAJ555hnccMMNeOeddzBjxgysXbsWiYmJ+Pjjjzv+k7SRQmG+kBunIBMREclLq0JKUVER5syZg/z8fHh6eiIyMhJbt27FpEmTAAA5OTlQKhs6Z0aPHo01a9bg73//O1588UX07t0bGzduxKBBgzr2U7SDdRIypyATERHJS6tCymeffXbZ1+Pi4posmzVrFmbNmtWqRl1NSgBGAIwoRERE8uLQ9+4BGm4yaOL5HiIiIllhSLE88mwPERGRvDCkWHtSmFKIiIhkhSHF8siQQkREJC8MKZZHDkkhIiKSF4YUS0rhxdyIiIjkhSHF8sieFCIiInlhSOHAWSIiIlliSLE8MqQQERHJC0OKNCbFvu0gIiIiWwwplkf2pBAREcmLw4cUpTQmxb7tICIiIlsOH1LYk0JERCRPDCmWR14nhYiISF4YUni6h4iISJYYUiyPJqYUIiIiWWFIsU5Btm8ziIiI6BIMKZZHDpwlIiKSF4YUyyMzChERkbwwpPDePURERLLEkGJ55LhZIiIieWFIYU8KERGRLDGkWB55MTciIiJ5cfiQIt27x2TfdhAREZEthw8pnIJMREQkTwwpvCw+ERGRLDGkWB45JoWIiEheGFIsj+xJISIikheGFIU5nXBMChERkbwwpFgeGVKIiIjkhSHFehdkZhQiIiJZcfiQYi0Ae1KIiIjkpVUhZenSpRgxYgQ8PDzg5+eHmTNnIjU19bLbrF69GgqFwuaPi4tLuxrdkTgFmYiISJ5aFVJ2796NefPmYf/+/YiNjYXBYMDkyZNRWVl52e20Wi3y8/OlP9nZ2e1qdEfimBQiIiJ5cmrNylu2bLF5vnr1avj5+eHw4cMYN25ci9spFAoEBAS0rYVXCzMKERGRrLQqpFxKp9MBAHx8fC67XkVFBbp37w6TyYTo6Gi8/vrrGDhwYIvr19bWora2Vnqu1+sBAAaDAQaDoT1NtmEwGKR79xjq6zt0352J9XM76ucHWAOANbBiHVgDgDWwam8d2ls/hWjjpVZNJhNuvfVWlJWVYd++fS2ul5CQgLS0NERGRkKn0+Ht
t9/Gnj17cOLECYSEhDS7zeLFi7FkyZImy9esWQM3N7e2NLdFn55W4nipEvf0NGK0P7tTiIiIOkpVVRXuu+8+6HQ6aLXaVm/f5pDy1FNPYfPmzdi3b1+LYaM5BoMB/fv3x+zZs/Haa681u05zPSmhoaEoLi5u04e8XFvu/WAHjpUoseSW/rhvZGiH7bszMRgMiI2NxaRJk6BWq+3dHLtgDVgDK9aBNQBYA6v21kGv18PX17fNIaVNp3vmz5+PX375BXv27GlVQAEAtVqNoUOHIj09vcV1NBoNNBpNs9t29MFiHTirVCod+kAErkx9OxvWgDWwYh1YA4A1sGprHdpbu1bN7hFCYP78+diwYQN27tyJ8PDwVr+h0WjE8ePHERgY2OptrwROQSYiIpKnVvWkzJs3D2vWrMFPP/0EDw8PFBQUAAA8PT3h6uoKAJgzZw6Cg4OxdOlSAMCrr76K6667DhERESgrK8Nbb72F7OxsPPbYYx38UdqGU5CJiIjkqVUhZeXKlQCA8ePH2yz/4osv8PDDDwMAcnJyoFQ2dNCUlpbi8ccfR0FBAby9vTFs2DDEx8djwIAB7Wt5B+FdkImIiOSpVSHlj4yxjYuLs3m+bNkyLFu2rFWNupoa7t3DlEJERCQnvHeP5ZGne4iIiOTF4UMKB84SERHJE0OK5ZE9KURERPLCkCKNSbFvO4iIiMgWQ4rl0cTzPURERLLCkMIxKURERLLEkGJ55JgUIiIieWFIsTzyOilERETywpBiHThr32YQERHRJRhSLI883UNERCQvDCkcOEtERCRLDCmWR/akEBERyYvDhxRrAZhRiIiI5MXhQ4p0uofne4iIiGSFIcXyyIxCREQkLwwp0sBZphQiIiI5YUixPPJibkRERPLCkMIpyERERLLEkGK51ixP9xAREcmLw4cUJXtSiIiIZMnhQwrHpBAREckTQ4olpRjZlUJERCQrDh9SnCwhpbbeZN+GEBERkQ2HDykuKvNjZW29fRtCRERENhhSLCGlnCGFiIhIVhhSLCGlooYhhYiISE4cPqRoVOYBs5V1DClERERy4vAhhT0pRERE8sSQ4mR+5JgUIiIieXH4kKKxVKCu3oQ6TkMmIiKSDYcPKdbTPQCnIRMREcmJw4cUlRJwUZvLUMGQQkREJButCilLly7FiBEj4OHhAT8/P8ycOROpqam/u926devQr18/uLi4YPDgwdi0aVObG3wluDubB6YwpBAREclHq0LK7t27MW/ePOzfvx+xsbEwGAyYPHkyKisrW9wmPj4es2fPxty5c5GUlISZM2di5syZSElJaXfjO0oXDUMKERGR3Di1ZuUtW7bYPF+9ejX8/Pxw+PBhjBs3rtltli9fjqlTp2LhwoUAgNdeew2xsbH44IMP8NFHH7Wx2R2ri2VgCkMKERGRfLQqpFxKp9MBAHx8fFpcJyEhAQsWLLBZNmXKFGzcuLHFbWpra1FbWys91+v1AACDwQCDwdCOFtuy7stNbQ4pusraDt1/Z2H9zI742a1YA9bAinVgDQDWwKq9dWhv/docUkwmE5599lmMGTMGgwYNanG9goIC+Pv72yzz9/dHQUFBi9ssXboUS5YsabJ827ZtcHNza2uTW1StLwGgREJiEhS5osP331nExsbauwl2xxqwBlasA2sAsAZWba1DVVVVu963zSFl3rx5SElJwb59+9rVgOYsWrTIpvdFr9cjNDQUkydPhlar7bD3MRgMiI2NRY+QQKSUFqJnn/6YPqZHh+2/s7DWYdKkSVCr1fZujl2wBqyBFevAGgCsgVV762A9E9JWbQop8+fPxy+//II9e/YgJCTksusGBASgsLDQZllhYSECAgJa3Eaj0UCj0TRZrlarr8jB4uFi3meVQTj0wXil6tuZsAasgRXrwBoArIFVW+vQ3tq1anaPEALz58/Hhg0bsHPnToSHh//uNjExMdixY4fNstjYWMTExLSupVcQZ/cQERHJT6t6UubNm4c1
a9bgp59+goeHhzSuxNPTE66urgCAOXPmIDg4GEuXLgUAPPPMM7jhhhvwzjvvYMaMGVi7di0SExPx8ccfd/BHaTtrSOEVZ4mIiOSjVT0pK1euhE6nw/jx4xEYGCj9+e6776R1cnJykJ+fLz0fPXo01qxZg48//hhRUVH44YcfsHHjxssOtr3auljuMsibDBIREclHq3pShPj9mS9xcXFNls2aNQuzZs1qzVtdVe7O5inI7EkhIiKSD4e/dw/QaExKDUMKERGRXDCkoOF0DwfOEhERyQdDChpO9zCkEBERyQdDChpO9+irHfvyx0RERHLCkALAy818sZny2noYTY57WXwiIiI5YUgB4OlqDilCsDeFiIhILhhSAKhVSnhYTvmUVtXZuTVEREQEMKRIPC2nfEqr2JNCREQkBwwpFt5uzgCAMvakEBERyQJDioV18GwZe1KIiIhkgSHFwtqTwjEpRERE8sCQYsGeFCIiInlhSLHwYk8KERGRrDCkWHhbe1J4nRQiIiJZYEix4OweIiIieWFIsZCuk1LJnhQiIiI5YEixYE8KERGRvDCkWHjzirNERESywpBi4eVq7kmpNhhRYzDauTVERETEkGLh4eIEpcL8dx1n+BAREdkdQ4qFUqngtVKIiIhkhCGlkQCtCwDgdH65nVtCREREDCmNTOjvBwDYnJJv55YQERERQ0oj0wYFAgDiUi+gsrbezq0hIiJybAwpjfQP9ECPrm6orTdh5+kiezeHiIjIoTGkNKJQKDB5YAAAID6j2M6tISIicmwMKZeIDPEEAJw8r7dzS4iIiBwbQ8olBgaZQ8rpgnLUG012bg0REZHjYki5RHcfN7g7q1Bbb0JmcaW9m0NEROSwGFIuoVQq0D9QCwA4cV5n59YQERE5LoaUZgwMsoSUPI5LISIisheGlGZYx6UcySmFEMLOrSEiInJMrQ4pe/bswS233IKgoCAoFAps3LjxsuvHxcVBoVA0+VNQUNDWNl9xI8N9oFIqcCSnDB/GZdi7OURERA6p1SGlsrISUVFRWLFiRau2S01NRX5+vvTHz8+vtW991fTwdcfiWwcCAN7elopzpVV2bhEREZHjcWrtBtOmTcO0adNa/UZ+fn7w8vJq9Xb28uB13fHD4XM4mluGQ2dLEOLtZu8mEREROZRWh5S2GjJkCGprazFo0CAsXrwYY8aMaXHd2tpa1NbWSs/1evMAVoPBAIPB0GFtsu6rpX1Gh3riaG4ZErNKcPMg/w57X7n5vTo4AtaANbBiHVgDgDWwam8d2ls/hWjHyFCFQoENGzZg5syZLa6TmpqKuLg4DB8+HLW1tfj000/x9ddf48CBA4iOjm52m8WLF2PJkiVNlq9ZswZublevRyOpWIHVaSqEuAssjDRetfclIiK6FlRVVeG+++6DTqeDVqtt9fZXPKQ054YbbkBYWBi+/vrrZl9vriclNDQUxcXFbfqQLTEYDIiNjcWkSZOgVqubvJ6vq8G4t/dApVTgf/Ni0KOrG9Sqa29C1O/VwRGwBqyBFevAGgCsgVV766DX6+Hr69vmkHLVTvc0NnLkSOzbt6/F1zUaDTQaTZPlarX6ihwsLe03zFcNf60GhfpaTH8/HrdGBeG92UM7/P3l4krVtzNhDVgDK9aBNQBYA6u21qG9tbNLt0BycjICAwPt8datNqKHj/T3HacKed0UIiKiq6TVPSkVFRVIT0+XnmdlZSE5ORk+Pj4ICwvDokWLkJeXh6+++goA8J///Afh4eEYOHAgampq8Omnn2Lnzp3Ytm1bx32KK+jF6f3Ro6s7PtiVjso6I86VViPUhzN9iIiIrrRWh5TExETceOON0vMFCxYAAB566CGsXr0a+fn5yMnJkV6vq6vDc889h7y8PLi5uSEyMhLbt2+32YecBXm54v+m9MXuMxdwPE+H43k6hhQiIqKroNUhZfz48Zc95bF69Wqb588//zyef/75VjdMbiJDPHE8T4cXfjyG93akYdWD
w9C9q7u9m0VERHTNuvamqlwhkSHm+/mU19TjdEE5VseftW+DiIiIrnEMKX/Q4GAvm+dHcsrs0g4iIiJHwZDyB/Xx74JATxfp+fFzZdBVOfaVCImIiK4khpQ/yEmlxP/+MhYHXpyAXt3cYRJAQmaxvZtFRER0zWJIaQXfLhr4a10wNsIXALDzdJGdW0RERHTtYkhpg2mDzReiW38kDxkXKuzcGiIiomsTQ0obXNezK27q54d6k8Ct7+/D2H/vRG5Jlb2bRUREdE1hSGmjl2b0h7NKKV2F9rN9WfZuEhER0TWFIaWNenXrgk3PjMW8G3sBANYl5uJgVgmKK2p/Z0siIiL6IxhS2iHCzwPPTeqLnr7uqKwz4u5VCbjjw3jU1hvt3TQiIqJOjyGlnZRKBZ68oZf0PKekCj8ezrNji4iIiK4NDCkd4O4RoYj7v/F4YWo/AMCHcekwGE12bhUREVHnxpDSQXr4uuPh0T3g28UZ50qrsWp3hr2bRERE1KkxpHQgV2cVXprRHwCwfEcaTuXrAQBCCI5TISIiaiUnezfgWjNzSDB+PZaP7aeKMHf1IdwRHYJfj+cj+2IlPn1oOLzdnGEwCowM97F3U4mIiGSNIaWDKRQKvHlXFO76KB6ZFyrxwa506bWXfzqBfF0NjCaBN++KxN3DQ+3YUiIiInnj6Z4rwMfdGV89OhIjw30wsb8/3rwrEm7OKpwrrYbRJAAAi9Yfx54zF+zcUiIiIvliSLlCQrzd8P2fYvDpQ8Nx9/BQzB4ZBgBwUSsxZaA/jCaBp9cm8XL6RERELWBIuUrm3RiBaYMC8OZdUXhv9lBEhXiirMqApZtP2btpREREssSQcpX4uDtj5QPDcGtUEDROKrx+x2AAwI5TRSjQ1eBA5kWkFZZDCCFtU6SvwV0r47EyjtOZiYjI8XDgrJ0MCNSip687Mosrcd3SHdLy24YE4Z1ZUXBSKfH02iQkZpciMbsUT43vdZm9ERERXXvYk2InCoUCMyIDpeeuahWclAr8lHweT35zBL8ey8f+zBLpdX2NwR7NJCIishuGFDtqHFI+enAYVtwfDSelAttPFWLemiM266YXVVzt5hEREdkVT/fYUb8ALd6eFQU3ZxVu6NMNALBx3hj889eTyCquxI19/XDsnA4n8/VIL6pAdJh3k33UG0147KtEBHq6YOkdkVf7IxAREV0xDCl2dtewEJvng4I9sfaJGOn5yz+l4GS+HhmWnpRzpVXo5qGBxkkFADhdUI64VPP1Vl6+eSBcnVVXqeVERERXFk/3yFyEXxcAQMaFCny2Lwtj/70Li348Lr1+prBc+vu5Ul5zhYiIrh0MKTIX0c0cUrafKsJrv5wEAKxPyoPJcuXaM4UNY1VyGVKIiOgawtM9MmftSbnU1/uzkVZUjtSChp6UnIsMKUREdO1gSJG5bh4aBHm64LyuBs9N6oODZ0uwN60Yr/x8osm6OSXVdmghERHRlcGQInMKhQJrn4iBrtqAwSGeeH9HGvamFTe7bg7vA0RERNcQhpROIKyrm/T3EeE+La6XkqfD8u1pCPZ2xcT+fvByc74azSMiIroiGFI6mSGhXvDtokFdvREeLmrklTWc4inQ12DZ9jMAgEHBWvzyl+vt1UwiIqJ2a/Xsnj179uCWW25BUFAQFAoFNm7c+LvbxMXFITo6GhqNBhEREVi9enUbmkoA4KJW4Ze/jMWWZ8fh+ydjMLG/H754eEST9U6c16O6zmiHFhIREXWMVoeUyspKREVFYcWKFX9o/aysLMyYMQM33ngjkpOT8eyzz+Kxxx7D1q1bW91YMgvwdEGQlyuCvVzx6UMjcGM/P+m1YC9X+Lg7QwjztVUAoLSyDgajyV7NJSIiapNWn+6ZNm0apk2b9ofX/+ijjxAeHo533nkHANC/f3/s27cPy5Ytw5QpU1r79tSCh2K64+ej5/HFIyPw940pOJhVgjOF5RACmLUqHhP6+2PFfdH2biYREdEfdsXHpCQk
JGDixIk2y6ZMmYJnn322xW1qa2tRW1srPdfr9QAAg8EAg6Hj7gZs3VdH7tNe/j69LxZN7QOVUoGIbm44mFWC1Hw9lm9PQ43BhF+P5WPZXXVQKBSoN5pw9ycHoQDw/ROjYDLWA7g26tBW19Kx0FasgRnrwBoArIFVe+vQ3vpd8ZBSUFAAf39/m2X+/v7Q6/Worq6Gq6trk22WLl2KJUuWNFm+bds2uLm5NVneXrGxsR2+T3uqLVIAUOH7A5korVNIy7/ZsBldXYBMPXA8z/xP/9+Nm+HrYn79WqtDW7AGrIEV68AaAKyBVVvrUFXVvktjyHJ2z6JFi7BgwQLpuV6vR2hoKCZPngytVtth72MwGBAbG4tJkyZBrVZ32H7tzSezBD9+kWgTUAAgqP9wTOjvh/d3ZgAnMgAAYYNGYEy41zVZh9a4Vo+F1mANzFgH1gBgDazaWwfrmZC2uuIhJSAgAIWFhTbLCgsLodVqm+1FAQCNRgONRtNkuVqtviIHy5Xar730D/ayeT4k1AvJuWVIL67CVLUaCVkl0mvZJTUY38f82a+1OrQFa8AaWLEOrAHAGli1tQ7trd0Vv8FgTEwMduzYYbMsNjYWMTExV/qtHVZX94aLuD0xriemDgoAAJwqKEdFbT2Scsqk1zOLK69284iIiP6QVoeUiooKJCcnIzk5GYB5inFycjJycnIAmE/VzJkzR1r/ySefRGZmJp5//nmcPn0aH374Ib7//nv89a9/7ZhPQE0oFAqsenAYnr4pAv83uS/6BXgAAE7n67FiVzrqLXdQBoDMCxWoqzfhk9NKvPrLKXs1mYiIqIlWh5TExEQMHToUQ4cOBQAsWLAAQ4cOxcsvvwwAyM/PlwILAISHh+PXX39FbGwsoqKi8M477+DTTz/l9OMrbMrAACyY3BfOTkr0CzCP48m4UImVceaxKLOGhQAAMi9UYl/GRaSUKvH1gVzUGHgBOCIikodWj0kZP348hBAtvt7c1WTHjx+PpKSk1r4VdRB/rQY9u7kj80IlnFVK/OPm/rhtaDDWHT6HovJaHMhsGKNyrrQaEX5d7NhaIiIiM1nO7qGOpVAosOHPY5BbUoWwrm7QupgHMvl20aC4ohbfJZ6T1s0trWJIISIiWWBIcRCermp4BnvaLOvt1wXFFbWobHSPn9yS9s1pJyIi6ihXfHYPydcT43o2WcaQQkREcsGQ4sBu7OeHe0eE2izLKamCEALfH8rF3NWH8Jdvk3hzQiIisgue7nFw/5w5CDf07orfDh7GN+kq5JZU42BWCZ7/8Zi0zuyRoRjdyxe7UougqzJg5tBgO7aYiIgcBUOKg3NSKTGxvx9yTphnbOWWVOHHI+ds1skoqsDAQE888sUhAMCIcB8EezV/tWAiIqKOwtM9BADwsdyFoLy2Ht9bZvsMDfMCAKQXVWDbyQJp3XMct0JERFcBQwoBAJxVgJ9Hw/2Sgr1cMXtEGAAg/UIFNh3Pl147r6u+6u0jIiLHw5BCkjG9fKS/P3Bdd0T4m6+XciS7DPvSi6XXzpfVXPW2ERGR4+GYFJL8+45BeHZSXxhNAuG+7iivrQcAVF9yqfzzZeaelPj0Yhw9p4NvF2fcER0ClVJx1dtMRETXLoYUkigUCnTv6i4917qopavSAsCAQC1O5uuRr6tBbkkVHvz8IIyWmxV6uzlj4gB/pBeV42R+OW6NCrLLZyAiomsHT/fQZVkDCgDMvykCgLkn5bN9WVJAAczjVgDgz/89gqe/TcLh7BIQERG1B0MKXdaTN/QCYL5rsvWePqcLyvHdoVwAQFSoFwBzcDlfVo0zheawkl5UcfUbS0RE1xSe7qHLmn9TBIaGeeGmfn6oaTQ2pdpgxIBALWYNC8HR3DLklVYjPuOi9HoeB9cSEVE7MaTQZXXROGHKwAAAgFqlhIfGSRpQ+2BMdwR4ugAA8sqqEd9oBlBeKacpExFR+zCkUKtYAwoA3BIVhHzLTJ+8smqUVtVJr1lnAFmZ
TAJKzv4hIqJW4JgUahWtiznX9ujqhi4aJwRZLo9fXlOPQn3DINvGF3xbsSsdkUu2YW/ahavbWCIi6tQYUqhVPrx/GMb37Yav544CALhrnODlppZe79HVDQCQX1YDk0ng24M5eGtrKipq6/Fz8nm7tJmIiDonhhRqlbG9fbH6kZEI9XGTljW+2eDtQ0OgVAB1RhOO5emw+OcT0mvnOE6FiIhagSGF2i2oUUgZ18cXAVrzYNq/fpeM2noTXNTmwyyN05KJiKgVGFKo3aw9KR4aJwwO9kSwt/l5VnElAODzh0cAMF8YTldlwKbj+Zi54jfk8m7KRER0GQwp1G49u5kvpR/TqyucVEq4OTdMGps9Mgyje/ki0DJV+XieDn/+7xEk55bho90ZdmkvERF1DpyCTO129/BQVNUZpfv1DA3zwu4zF6BQAK/cMgAAEOHXBfm6Gvzz15PSdo2nKZdW1uHTfZm4IzoEvbp1ubofgIiIZIkhhdrNRa2SLp8PAA/F9IDGSYXbhwbDRa0CYA4pe9OKcbqgXFqv8d9nf7IfpwvKcSq/XDo9REREjo0hhTqct7sznhrfy2aZ9b4/ABDq44rckmrk62pwsaIWJ/P1UmDZebroqraViIjki2NS6Kq4rmdXOCkViArxxM/zxqKnr3kcy770YrzwwzFpPSelAnX1Jns1k4iIZIQhha6KXt264PDfJ2HDn8fA290ZA4M9AQDPrE3GeV2NdBG4epNA9sVKFFfUYsWudBSVN9yosMZgxD9/OYl9acXNvgcREV1bGFLoqvF0U0v37xkYpJWWuzmrsOrB4RgS6gXAPFblqW8O462tqVgZ1zADaPmONHy6LwsPfHbgqrabiIjsg2NSyC5uGxKEzcfz0cPXHfNvjEBvfw/08e+C5Nwy/HvLaenqtKfy9dI2jS+rbzQJqHjDQiKiaxpDCtlFoKcrfpo/1mZZbz8PALaXz0+3XKW2qLwGeY2mLJ8rrUL3ru5XoaVERGQvPN1DstHbv2EG0PW9fQEAxRV1KKmsw6Zj+TbrZlzgJfaJiK51DCkkG4OCPeGiVqKruzPeuTsKIZbL6x89V4aPdmfarJt5oRIpeTrcuTIePyXn2aO5RER0hbUppKxYsQI9evSAi4sLRo0ahYMHD7a47urVq6FQKGz+uLi4tLnBdO3y7aLBpqevx+Znroefhwt6W66tsuC7ZBToaxDq44o/jesJAIg9WYib39+Hw9mleHNLqj2bTUREV0irQ8p3332HBQsW4JVXXsGRI0cQFRWFKVOmoKio5YtwabVa5OfnS3+ys7Pb1Wi6dvXs1gV+lrso9/E3j1EprTIAAF6aPgADLLOCDmSVSNvklVWjxmBscZ+7Uot4eoiIqBNqdUh599138fjjj+ORRx7BgAED8NFHH8HNzQ2ff/55i9soFAoEBARIf/z9/dvVaHIMvRpdpXbyAH9MGeiPnr4NyzxcGsZ9p+TpUFFbjw92ptncXXnT8Xw88sUhPPZl4tVpNBERdZhWze6pq6vD4cOHsWjRImmZUqnExIkTkZCQ0OJ2FRUV6N69O0wmE6Kjo/H6669j4MCBLa5fW1uL2tpa6bleb56GajAYYDAYWtPky7LuqyP32RnJtQ6DAs2BRKVU4NVb+6O+vh6hXs7S6/PH98TBs6XYcfoCjmSX4H9H8/BlQg5O5Onw3r1RAIBlseZTQVnFlaiuqYWTqvlcLtcaXE2sgRnrwBoArIFVe+vQ3vophBDij658/vx5BAcHIz4+HjExMdLy559/Hrt378aBA00vspWQkIC0tDRERkZCp9Ph7bffxp49e3DixAmEhIQ0+z6LFy/GkiVLmixfs2YN3Nzc/mhz6RqQrgd8NOY/VtvOKaCvU+D2cBN25Cnwa64KA71NyNQrUG1UwN1J4J/DjbhQA7ye3JDD/zG0Hr4cDkVEdNVUVVXhvvvug06ng1ar/f0NLnHFr5MSExNjE2hGjx6N/v37Y9WqVXjttdea3WbRokVY
sGCB9Fyv1yM0NBSTJ09u04dsicFgQGxsLCZNmgS1Wt1h++1sOlsdpjf6e9eMi/h19WGcKG3oIamsV6DP8OuRkJADoGHmT3jkSFwf4dvsPpOzS3Do4H48PLNz1OBK6GzHwZXCOrAGAGtg1d46WM+EtFWrQoqvry9UKhUKCwttlhcWFiIgIOAP7UOtVmPo0KFIT09vcR2NRgONRtNkuVqtviIHy5Xab2fTGesQ3aMrXNUqVFsGzjopFag3CaxPLsCPSeYr1HZ1d8bFyjrkldVCrVZDCAGFouFqtWeLKzHr00Q4K1V46HZVp6tBR+uMx8GVwDqwBgBrYNXWOrS3dq0aOOvs7Ixhw4Zhx44d0jKTyYQdO3bY9JZcjtFoxPHjxxEYGNi6lhI1w8NFjfV/Ho2FU/ri6Zsi8OfxvQAAn+3LgtEkMK5PN9w+NBgAsOZgLiYv242+/9iCpZtPSftYHX8WAFBnUiCzuPKqfwYiImpeq2f3LFiwAJ988gm+/PJLnDp1Ck899RQqKyvxyCOPAADmzJljM7D21VdfxbZt25CZmYkjR47ggQceQHZ2Nh577LGO+xTk0PoHajHvxggsmNwX4/p0k5arVQo8N6kPuvuaL59/Kl+PM4UVqKs34YvfzkJXZUB5jQE/HD4nbZNaUN7i+/zr15N4/KtE1NWbrtyHISIiSavHpNxzzz24cOECXn75ZRQUFGDIkCHYsmWLNK04JycHSmVD9iktLcXjjz+OgoICeHt7Y9iwYYiPj8eAAQM67lMQWUSHeeNPN/RETZ0Rs0eFoV+AFvqahtHlCgWgUihQV2/CT0fzkJRThoraeun104UVKK2sw8NfHET/QC3euDMSAHDyvB6f7M0CABw6W4IxLYxtISKijtOmgbPz58/H/Pnzm30tLi7O5vmyZcuwbNmytrwNUasplQosmtbfZll3n4YbEV4X3hUT+vvhn7+ewss/nQBgHscyeYAfNqUUIrWgHP/4KQVHz+lw9JwOf5vWD15uzvj8tyxpH8m5ZRgT4YujuWUwGE0Y3sPn6nw4IqJLLP75BPZnXsR3T8TA06358R/ny6pxrrQaI8N9IITAa7+cws7Thfj2iesQ6OmKpJxSpJzX48Hrul/l1v8+3gWZrnlBXg3zjq1jVN7ckoo6owlKBfD67YPRo6sLNqUUYm/6RZttE8+WYnCIJ35OPi8tS8opMw+2/SgBAgK/vXAT/LQuKNLX4P2d6cgtrcKyu4fA290ZmRcqkF1ShRv7+tnst8ZghItadWU/OBF1ajUGI5QKBZydzGcn6o3mU83W6z2lF5VLY+piTxXirmFNL+vx2i8n8cVvWTAJ4KMHopFZXCn90vVT8nlMGxSAx75MxMXKOrg4KTFreOhV+GR/HG8wSNc8J5US948KQ1SoFx64Lgxdu2iw+tER+NftgxD3fzfi7hGh6NPo6raNHTxbgld+OoE6ownWCUHJuaVYuvkU6owmGIwCO04XoaquHjNX/Iav92cjLvUCfjxyDgajCQ98egCPfHEIR3JKpX2u2JWOfv/Ygr1pF1psc2llHd7ZlipdPbfeaELi2RIYTX/4skZE1E5GAWw7WYii8poO33deWTXKqupafL26zoib39+HG9+OQ43BiOo6I256ZzduW/Gb9HPgs31npfVT8nQAgKLyGmRabgNyukCPz/aZAwoAfLo3C29vbbjX2dYTBXhk9SFcrKzDwCAtpg+W34QWhhRyCP+6fTB+mjcGHi7m7tDRvXxx/6juCOtqvjigu6ahU9FJqcDrtw8GAHy8JxNbThTASanA+qdGQ61SoLiiDltPNEzD336yEAcyS3Be1/CDLC71AraeKJCW7TxlvrdVcUUt3rL8kPh0b8MppEst+d8JvL8zHW9vM6/7/s503PVRgvRbExG17KfkPNy24jdkX2x5tl5JZR3e25GGC+Xmq5tnXqjA4p9PoLjC/NxgNGH1GSXmfXsUi348DgDQVRnwn+1nLrtfwBwYdJZ7jmVf
rMSTXx/GrtMN97fbc+YCxr+1C7M+SoCuyoDl29Mwd/UhfH8oV1rny4SzSC+qQF5ZNY7klGL3mQvIKanCifN6xGcUo6i8BuuPNAz6T8opxd60C5jw9m5MXb4XORersNFyGQYfd/OVuhOzS2ESgL9WY9mmDJkXKhGgdcHnD4+w+TkoFwwpRBa39zCif4AHtjx7Pcb1sR0Y++fxvTA0zBsh3g1XPJ7Y3zxYfF96MWJPmUPLyHDz+JQDWRfx4a4MaV1rr0njYGId0JtxoQJfxp+FwdKVm3mhAj8fNf9w2Z95EUIIrE86Z7MfInupN5rwxubT+Ck57/dXthBC4JM9mdiQdO73V/4duioDzhS2PAsvq7gSz6xNxtHcMnzx21kAwKd7MzHp3d3IK6uW1lu0/hjejT2DV385CQB4/odjWB1/Fu/vSAMAvLs9HcdKzF+RO04XobzGgMe/SsR/tqdhyf9Otvj+6xJzcfP7+/Dsd0koKq/Bze/tw5YTBXjhx2MwmQTSiyrw2JeJMBgF0ooqsOD7ZCzbfgY7Thdh8f9OoLrOCH2NASvjGn5+HMoqxbYTBdLzDUl5WBZ7BrX1JnTzMAeOo+d0ePCzgyivrUddvQnfHMjGz5Z/o8W3DrS519mCSX0Q5tPws+z5qX3hr5Xn5bgZUogsxgcK/DwvBhF+HgjxdkOojysAYNqgADwzsQ8A4LYhQQCAh0f3wMcPDkOwlytq601YcyAHAPDgdd0R6uMKg1HgZL4eKqX5HNGxPB0Sz5bgi0YDcE/l67EvrRgT3tmNV34+gXWJ5h/gH+xMl7pnC/W12JJSgNwS8w/XY+d0uNydLHRVhsu+DgDnSqtQ2WhGE1FrbD9VhI92Z+DF9cf/8OnHo+d0+NemU1i47hiq61q+Y3nGhQrptAUAlNcYUFXXcKwKIfDI6oOY+p89OJxtPoWaW1IlnU6tqzfhue+TpfWP5+lQYzBi+Y40pBVVSGPLknPLpN7QLSn52HqiAImW/W05UYDqOiO+S7QNVO9sO4ODZ813X995uggmk8ChsyW475P90vvnXKzCwh+OAQB2pV7Ai+uPo9zyf62ovBb7My/io90ZqDM2XMZgR6Melqo6I3acLsQnezKhqzZIp5gTMott1lt/JA/fHjT3unx4f7TUU9LYx3sycV5XAw+NEyYP8Mf1vc2/eDk7KTFtcCBusFyuIcKvC24bEtzsv4ccMKQQteD92dH458xBeH/2UClsPH1Tb8T/7SYsvnUglEoFZg23Hag2JsIXNzUaJPvC1L7oF+ABIYCHvziE2noTxkb4wlWtQo3BhAc+a7jf1fZThcgqrsRGy28/AZbfbKy/6QHmLupzpdWISy3CPasScLbRxee+TjiLoa9tw5uW00lCCKTk6aQeGgA4nF2C8W/FYf6aIy1+bqNJ4IfD52x6bdKLyqVucavvDuVgxa70y4ais8WV2JJS0OLrzRFCXPaLTI7qjSbM/ng/Zq74zabeHSG1oBz/PZANkyUQCCGu2Ngkk0ngx8PnkNPoTuKX+vmo+fisrDMiragcmRcqcNPbcTYXSMwtqUKNoeHfcKulF6DeJHAyX4fmrEvMxdT/7MEdH8ajQFeD4opaTHx3NyYv24PaevO+knLLcCSnDCYB/PdANsprDLhjZTzu+DAe/zt6Hn9bfwxHcsqkfR47V4bNKfkorzEHhYNZ5p7JpZsa2mowCvzp68PS80J9Lf695TTKa+rh7Szw6GjzjJdLT7VuO1mIx75MRHzGRfzT8n/031tO26yz3XKad1Cw+XYun+7Lwq/H8gFACg2AOTg8PLoHAGDV7kx8ts/8y8xfLb8c7c8sga7agK7uzjY9IDMiAzGihw+6dWm4QvuBFyfAt9HzR8eGw0Wtws2R5l+wbo4MhNZFjT/d0BMzIgPx7t1R0s83OWJIIWrBkFAvPHBdd5s7JyuVCgR5uUrP/zw+Al0s53F9uzjDx90Zj13fE7dEBeGLh0fgiXG9pNNCFbX18HRV
4527ozAwqOk9qOIzivHOtlSYBHBTPz/cOcz8202+ZVyL9beqxOwSvLj+OA5klUg/ODcdz8fLP5+ASQA/HD4Hk0ngg53puPn9fVgWe0Z6j4XrjqHeJLAr9QJqDEYUlddg0frjUiCpMAB3rTqA/1t3FE98dRg1BiPOFJZj2vK9ePCzA1IgKdDV4G/rj+OtralIyWv+3hxCCDy6+hCe/OYw4jOK/1DNhRD4y7dJiFyyFSfPX/6eH6WVtoMOTZd8cacVluPNLaeRc7HlL9zqOiPe3ZaKg1kl0rL6S0JGWVUd5q85gg92mk8DGE0CK3alY8+ZhhD36/F8JGReRHJuGU6c18NgNOHJrw/jue+PthjihBB4d1sqFq472uQ9G68zf80RvLQhBT8cOYfyGgMe+OwARr2+w2Yw58WKWqxLzL3shQZLK+uwancGCvU10r4zLlTYtO+/B7Lx3LqjeGF9SrP7KK8xSF+8AHAoqwTPrE1GZnElPtubhYsVtfjx8Dlc/+YuvLjhuLTe1kanKpJzzb2BB7NKcK7U/G/zU3IeFv5wDAajQJ3RhB2nC/H6r6dQqK/FudJqJJ4191R8k5At7WfT8Xz8e8tpKTz/5dskrD+SB5VSgdWPjIC/VgODUWDxzw0hP/FsKb47lIsDWSVwUSsx/8YI6TVXtQrDunsDaAgkI7sJTOzf8EtHhF8XKXA8+c1h6KrNp2yP5JThl2PnscXyOSNDPKVtQn1csfR28/WWdp4uQrXBiJ7d3DGv0XuP6dUV94wwz6o5nqdDVZ0RUSGeNusAwO1Dg/Gfe4fgsbHh+Pedg7Hs7iEAgGcm9oarWoV3746Cv9YFr9wyAGE+bnht5iA8O7E3AGD64EBsnDcG/5ppHm8X4u2GFfdFIzLEC3Imv1EyRJ2Is5MSP88fg8X/O4kHRoUBAEJ93PD+7KHSOvNujEBv/y4o1NdgdC9f+Gtd0DfAQ+pefvOuSCyLPYN8XQ1+sfyW9cyE3iirNmCFZVzLgEAt+gV6YP2RPLy4PkW6V9HuMxewP/Minl2bDOt3zYXyWmw9UYAVceb7Y313KBd/ndQHJ87rbS77v/vMBSz5+QTO62oQl1qEnX8di23nlEgpMIeDaoMRh7NLsfN0EQxGgdMF5UjIvIijuTqUVdVJ7/dbRjEGN/qhbHU4u1R6v9/SizG6ly+2pOQj40Il/jy+FxQKBQxGE5QKhfSb3Cd7M6UabE7Jx4BmwhwAvLnlND6My8B7s4fi1qggvBt7Bh/FZeDbJ0ZhWHcfrD9yDgu+PwrAfFrti0dGorK2HltSCjBlUIAULBf/fALfJeZizcEcJP59Eg5mleCBzw5gvL8S02E+fXb3qgScKazAL8fy8cB13RGXegFvbU1FV3dnHHppIhQK2IwfSMopNfcgWb6wZo8MxfAePjhXWgV3Zyd4W7rmV+7OwHs7zf9Gtw8NxqieXZFeVAFPVzUCPM29aKmF5UgrMs/U+HxfFtYcyEFybpm5PscL8NDoHqgxGPHAZwdxKl+P3NJqjOzhg9iTBVAplfjrpN7wcFGjus6Ih1cfwtHcMhw7p8OK+6OxfEca/rM9Df+cOQgPXNcdRpOQLlh4OKcMM7sBhfoafBB3CnePCEVvvy547ZeTNkHoH5ZrDQHmXpJ3Ys9Ipz7XH8nDG3dEIqekEpkXGo67IzmlSC+qwLcHcxCgdcGnDw3H335sCDSAuTehcW/OLsuYEOux4eWmRlmVAd/sN79XmI8bckqq4OWmxqu3DcL4vn4YGd4V/zt6XgoSTkoFymvr8bf15vf668Q+eGh0DxSV10DjpMJDo3sgq7gSj3+VCABQKRUY6WdCdJgXxvXpBpNJ4L3ZQ7HmQLYUzLu6O2NomDe2nyrE/DVJAIDRvbri7uGhePa7ZADAtEGBGBziicfGhuNTSw/J3cNDMTTMCxonJWrrTbipvz/6BXhgTERX/JZ+EQFaF7x8y0ColArcHBmIX47lY8pAf7wwrR/UKiWiw7xt
6jV9cCCmDQqQ7kl2S1QQbokKwqWGhHo1WSZ3DClE7dSzWxd89ejIFl93dVY1Oefbu9GU51ujgnAkuxRrLSP7rdOl6+pNmDYoAP5aF7wwtR/+d+w81h/JkwIKYB4keO/H+wEAkwf4Q6VUYHNKAZ76b8PpnIuVddhz5gJW7LK9qWfjLu58XQ22nCjEwQvmH3LebmqUVhmw58wFbExqGCD56OpDqDHY/rb+W3oxbokKwid7MuGuUeG5SX2hVCrwY6OZB4eySpFbUoW/fJsEg1EgMsQTI8N9cOfKeJRWGrDtr+NQbTBKM58A88wDAHj5pxTsSyvG14+NQrCXK+IzivGhJRS8vyMNff09sGJXOowmga8SshEV4oXXNzV0u+9LL0Z5jQF/+vow4jMu4ui5Mozu1RW7zxTju0RzzYsr6lCgq8HbW1NRV2/CvkJzgFoRl4kzhRXSvvakFePr/dlSXU+c1yOvrAqnG91O4UhOGdIaDexcczAHXm7OuPn9vQjzccPWZ8fhw7gMm8/6zYFsPL02CcUVdfB0VWPPwhvh6aaWTg0AsHkPwHx68N6RoVj4wzGcyjd/aX6+Lwvv70yTAqRJCLx88wD89TvzQFLAfD2N3JIqfLInEwDwy7HzeOC67tiSUiAFAyGAk6UKbNt4AnvSLuJUQTm83dSISzX3Hk3o52czRuL63r7Ym1YsBZSGf8NSfHvQvMzTVQ1dtcHmMxXoa3Dvx/tRbTBibIQvFk3vhxnv7ZPa0aubOzIuVOLTfVnSF/zE/n64sZ8fXtpg7u2ZPTIUL80YgLjUIlzfuxs8Xc0z+K7r6YP/WQagzx4ZhpySSvxmuQ5SVKgX5o4Nh5NKiTfvipLa06ubO5bfOwTFFXXo6+eGi6f2Q6VU2Pz/vq5nV+nvS+8YjHBfd8RnFKPKcorykTHhiGoU2qcNMt989+83D0C/QC3iM4px/6gwaJxUeOz6cOxNK8bNgwOhUCjwzdxRqDGY4OrccA2l124bhDuHheD6CF+bXt1LNb5p6rWEIYXIDu4dGYa0ogpMHhgAF7UK944Mw89Hz+OO6GC8eusgAOZempUPDJO2Gd+nG7p5aHCxohaDQ7wghMCxc+bz+wODtHhv9lBsSSnAZssYEGeVEsN7eCM+4yJe+PEYiivq0EXjhDujg/Flo25z6xfM338+iWqjAiFeLnh6Qh88/+MxrLJ8kVldGlAAYG9aMca8sVN6Xmsw4fo+3fDL0YYvo+RzZXhj82kYjOZvz20nzFf3tf5GujetGAW6ahiMAs4qJeqMJiTllCL2ZCG+srT13W1n8NT4nnj622Rpv2lFFXji60RpjMb2k4XYcboIxRW18HZTQ61Soqi8Fg9+dlDqgdiQlIe1B3NtBi8CwPIdadLAyKp6BfamX8T3lhAT4dcF6UUV+HBXuk1Y2HayQBpDNDTMC0k5ZdIXo/WO3L8ey4e+2oAagwlnCivw4obj0qBHPw8Nisprsel4w+kQXbUB65PO4UJ5rXRfKRe1Uqr9O7Oi8Ny6o0jIuIhZHyXg2DkdlArAy80ZJZZTYFEhnjh6Todv9mejuKIWW04UwFmlhLe7GoX6Wsx4by8qLV+qh7NLkZRTipc2mnsYrHcNX5OhAmD+UrcGHKUC+PzhERgQqMXI1803mu3p644PZkdj1NLtqDGYMDTMC65qFeIzLuL5H48h+2IVlArgP/cMwSOrD0mfMzrMC0cst6XQujjh3buj0M1Dg2AvV+SVVaObhwZfPjoSY/+9S9rm4dE98NKM/nBSKnB9RDd4uauhtVxWwDrmwuqOoSEo0NVgRA8fjOvTDR/vycBv6RfholbikznDmv3CVygU0i8UBoMBjYauNGq3Nx4bG44ATxdMHmgOIJufuR7/O3oezk5KTOjnB6VSgRem9oOu2mDTe3HXsBCbC64tnNIPC6fYvn/jgAIA3u7OTS4G6Ug4JoXIDlzUKvzr9sHSCPshoV5IWTwF/5w5GMoWBrH5aV1wYNEEpP9r
On6aN8bmB9ebd0XCRa3CjX394OWmRld3Z/z38VH427R+AMw9BQDwt2n9MKPRD/Pre/vi9dsHQ6EAKmvNX1r3jgjFmN62U7CnDw6Q/j6ihzf6+nvg6Zsi4NzoB33/wIbBgQ99bp4K2S/AA75dNKirN+HX4w2h5b8HsrHcMtUTAHafKcJGy8yL56f2hYfGCZV1RqnrHQA2JJ3DnSsTUFxRi/6BWmlcT/bFKmiclPDt4ozKOiNetHTn3xIVhCmWLxFrQHF3VqG8ph51RhN6dXPH3LHhmGX50rD+xq9Wmev/p2+SUFZlQLCXK169dSCAht4M6+mi93emI7ekGgFaF3zUKFACwNzrwzEgUIvaepPNOA5rQFkwqQ9+nj/WZpuxlntCvfbLSXwYl4Gi8lq4qlX48pGRGBzsiRX3RePOYSHo1c0d9SZzSPV0VeOzh0ZI4yu6eWjw1aOjMKGfH+pNQjpF8tasSDwyJhwAoLcMJFUqzANHb/8wHmVV5i/UD+6LtmmT9WqnAHB9724Y39cPfloX9LTcuHPR9P7wdFPjm7mj8N7sofjhydG4MzpE+rcBgKfG98KN/fykAZ23RAXhq7mjoLVMi/37zQPgp3WBQqHAvSNC4aRU4J8zByHE202qyV3DQvDKLQOgVimhUCgQ1tVNCijNcXVW4blGNx198LoeWHzLAOx8bjz8PNo+3VapVODvNw/AY9f3lJZ17+qO+Tf1xhPjekn/f58a3wt/m9bvmu3huFrYk0IkEy2Fk5bWuX9UGPakXcDMIcEYGGTuXvZ0UyPu/8ZDrVJKF2b6/OHhOJGnh08XZ8weEWbTg2CeMu2GlfdHIz69GOdzzmLOdWHQupu/hDKLKzG8uzeW3TMEuSUJSC0ox+JbB0rvV1JVh2/25+CO6GC8dVcUvow/i1V7MlBjMOHmyEA8P6UfnluXLH1J3zcqDGsO5MAkIM24ABq+uFVK82+yO08XIT7D/Ft8qI8rIrp1wa7UC9BVGzA42BNfPToSvxzPxz82mrv835s9FIlnS/DJ3ixctPQmzBwaDF21QTo98/RNEag2GKVxF3+b1h+TBvhjb9oFrLP0WGhdnPDClD546aeGwZb3jQrDiHAfqTfDy02ND++Pxn2fNMzMevmWAU2uM/HMhN64KzoEj355CLkl1Ta9IW7OKswdGw53jZPUS9PX3wOvzRyEG9+Ok6agz7uxF27q549h3b3xv780BJqbI4OwfEcaIkM88eH90eYvc6MJJiEwupcvPN3UWHzrQNQZTdC6qDEjMhDTBweiSF+Dj/dkwiQEHhkdjryyKnxvmWob7uuO1Y+MgNZFjZv6dsOpnCLMvbEfVCqVdF2QxrPZPn1oOAp0NRhtCRGN72F1fW9fKBTm00a3Dw3Gs5ZZKu/dOwQJmRcx78YIuKhVWP3oSGRfrMTMRqdD/zKhNx4f11O6bcTbs6KQlFOKSQP82/WF7+qswsOWkEadB0MKUSflp3XBhj+PabLcy832mgk39fPHTf38pecuShWW3zsE+boaTBpgXj51UCAm9PXFpk2ZUnfzW7OicCDrIh4dEw6Nkwpfzx0JXbUB3bs23LDxpekDcNewUEQGe0KpVODRseF4dKztF8Gd0SGIS72Ah0b3wEvT+0vjElzUSvxv/lhMWrZHWvcGyymt0b26SiHli4dHwNvNGV/Gn0VUqHkQo1qlxF3RITh5Xo+xEb6YMjAA/QO02HOmGBcr63BdTx8MDfWCwSgwZaA//Dxc8OzEPpYL52WjT0AXTOhn7oka3t0Hbs4qVNUZ8dasKFzfyxuf7TyBcqHBqJ5d8WBMd6hVSrx66yDsTS/GP2b0h5/WBSN6eONUfjnenhWFqZZxBy9O74cVuzKw6sFhcHN2Qm9/D/w8byw2pxRgULAWt37wGwBgxuBAKUTOHBKEt7edwZ9v7IVwX3cMDvbE8Twd7h4egoVT+jX7bz/vxgiM7tUVQ8O8pZ4OtUpp89t9qI8bvp47qskxE/+3m6BS
KqBWKfHrsXwppLw/e6h07Kx6YCg2bdqE6THdUVpjxDvbzqCLxkmaqQaYx2L17Nb87ST8tC74zz1DUF1nxD0jQqVwMTrCVwo1gPnUyaWDQAHY3NcqwNMF02R4uXa6SkQnoNPpBACh0+k6dL91dXVi48aNoq6urkP329mwDqyBEFe2BjWGeunvORcrxb83nxLnSquEEELMXX1IdH/hFzHuzZ2iUFcthBCirKpOLN10Spw837H/563vr6u2/YyJZ0vE7tQiIcQfr4Oh3igqaw1NlptMpha3mb58j+jxt1/E4ewSaZnRaBKF+mrp+fFzZeKtLaeFvvrKH4uGeqN4/deTYtuJApvll9Yg52Kl9G/jKPgzway9dWjv9zd7UojoitM4NfxmHOrjhuenNvQQLL51AMZGdMWdw0Kkeyt5uqql8TQdLbTRxbCsrNfHaA0nlbLFwZct+fzhESjU19hcm0KpVNiMkRgU7IlBwU2ndF8JTiolFk3v/7vrNVczoquBIYWI7CrE281hxgr4a11ke48UIjni7B4iIiKSJYYUIiIikiWGFCIiIpIlhhQiIiKSJYYUIiIikiWGFCIiIpIlhhQiIiKSJYYUIiIikiWGFCIiIpIlhhQiIiKSJYYUIiIikiWGFCIiIpIlhhQiIiKSpU5xF2QhBABAr9d36H4NBgOqqqqg1+uhVqs7dN+dCevAGgCsgRXrwBoArIFVe+tg/d62fo+3VqcIKeXl5QCA0NBQO7eEiIiIWqu8vByenp6t3k4h2hpvriKTyYTz58/Dw8MDCoWiw/ar1+sRGhqK3NxcaLXaDttvZ8M6sAYAa2DFOrAGAGtg1d46CCFQXl6OoKAgKJWtH2HSKXpSlEolQkJCrtj+tVqtQx+EVqwDawCwBlasA2sAsAZW7alDW3pQrDhwloiIiGSJIYWIiIhkyaFDikajwSuvvAKNRmPvptgV68AaAKyBFevAGgCsgZW969ApBs4SERGR43HonhQiIiKSL4YUIiIikiWGFCIiIpIlhhQiIiKSJYcOKStWrECPHj3g4uKCUaNG4eDBg/ZuUpssXboUI0aMgIeHB/z8/DBz5kykpqbarDN+/HgoFAqbP08++aTNOjk5OZgxYwbc3Nzg5+eHhQsXor6+3maduLg4REdHQ6PRICIiAqtXr77SH+8PW7x4cZPP2K9fP+n1mpoazJs3D127dkWXLl1w5513orCw0GYfnb0GPXr0aFIDhUKBefPmAbg2j4M9e/bglltuQVBQEBQKBTZu3GjzuhACL7/8MgIDA+Hq6oqJEyciLS3NZp2SkhLcf//90Gq18PLywty5c1FRUWGzzrFjx3D99dfDxcUFoaGhePPNN5u0Zd26dejXrx9cXFwwePBgbNq0qcM/b0suVweDwYAXXngBgwcPhru7O4KCgjBnzhycP3/eZh/NHT9vvPGGzTpyrsPvHQsPP/xwk883depUm3U6+7HwezVo7ueDQqHAW2+9Ja0jq+NAOKi1a9cKZ2dn8fnnn4sTJ06Ixx9/XHh5eYnCwkJ7N63VpkyZIr744guRkpIikpOTxfTp00VYWJioqKiQ1rnhhhvE448/LvLz86U/Op1Oer2+vl4MGjRITJw4USQlJYlNmzYJX19fsWjRImmdzMxM4ebmJhYsWCBOnjwp3n//faFSqcSWLVuu6udtySuvvCIGDhxo8xkvXLggvf7kk0+K0NBQsWPHDpGYmCiuu+46MXr0aOn1a6EGRUVFNp8/NjZWABC7du0SQlybx8GmTZvESy+9JNavXy8AiA0bNti8/sYbbwhPT0+xceNGcfToUXHrrbeK8PBwUV1dLa0zdepUERUVJfbv3y/27t0rIiIixOzZs6XXdTqd8Pf3F/fff79ISUkR3377rXB1dRWrVq2S1vntt9+ESqUSb775pjh58qT4+9//LtRqtTh+/PgVr4EQl69DWVmZmDhxovjuu+/E6dOnRUJCghg5cqQYNmyYzT66d+8uXn31VZvj
o/HPEbnX4feOhYceekhMnTrV5vOVlJTYrNPZj4Xfq0Hjz56fny8+//xzoVAoREZGhrSOnI4Dhw0pI0eOFPPmzZOeG41GERQUJJYuXWrHVnWMoqIiAUDs3r1bWnbDDTeIZ555psVtNm3aJJRKpSgoKJCWrVy5Umi1WlFbWyuEEOL5558XAwcOtNnunnvuEVOmTOnYD9BGr7zyioiKimr2tbKyMqFWq8W6deukZadOnRIAREJCghDi2qjBpZ555hnRq1cvYTKZhBDX/nFw6Q9lk8kkAgICxFtvvSUtKysrExqNRnz77bdCCCFOnjwpAIhDhw5J62zevFkoFAqRl5cnhBDiww8/FN7e3lINhBDihRdeEH379pWe33333WLGjBk27Rk1apT405/+1KGf8Y9o7svpUgcPHhQARHZ2trSse/fuYtmyZS1u05nq0FJIue2221rc5lo7Fv7IcXDbbbeJm266yWaZnI4DhzzdU1dXh8OHD2PixInSMqVSiYkTJyIhIcGOLesYOp0OAODj42Oz/L///S98fX0xaNAgLFq0CFVVVdJrCQkJGDx4MPz9/aVlU6ZMgV6vx4kTJ6R1GtfMuo6capaWloagoCD07NkT999/P3JycgAAhw8fhsFgsGl/v379EBYWJrX/WqmBVV1dHb755hs8+uijNjfmdITjwCorKwsFBQU27fX09MSoUaNs/t29vLwwfPhwaZ2JEydCqVTiwIED0jrjxo2Ds7OztM6UKVOQmpqK0tJSaZ3OUhfA/HNCoVDAy8vLZvkbb7yBrl27YujQoXjrrbdsTvVdC3WIi4uDn58f+vbti6eeegoXL16UXnO0Y6GwsBC//vor5s6d2+Q1uRwHneIGgx2tuLgYRqPR5gcxAPj7++P06dN2alXHMJlMePbZZzFmzBgMGjRIWn7fffehe/fuCAoKwrFjx/DCCy8gNTUV69evBwAUFBQ0Ww/ra5dbR6/Xo7q6Gq6urlfyo/2uUaNGYfXq1ejbty/y8/OxZMkSXH/99UhJSUFBQQGcnZ2b/ED29/f/3c9nfe1y68ilBo1t3LgRZWVlePjhh6VljnAcNGZtc3Ptbfx5/Pz8bF53cnKCj4+PzTrh4eFN9mF9zdvbu8W6WPchJzU1NXjhhRcwe/Zsm5vGPf3004iOjoaPjw/i4+OxaNEi5Ofn49133wXQ+eswdepU3HHHHQgPD0dGRgZefPFFTJs2DQkJCVCpVA53LHz55Zfw8PDAHXfcYbNcTseBQ4aUa9m8efOQkpKCffv22Sx/4oknpL8PHjwYgYGBmDBhAjIyMtCrV6+r3cwrYtq0adLfIyMjMWrUKHTv3h3ff/+9rL44r5bPPvsM06ZNQ1BQkLTMEY4DujyDwYC7774bQgisXLnS5rUFCxZIf4+MjISzszP+9Kc/YenSpdfE5eHvvfde6e+DBw9GZGQkevXqhbi4OEyYMMGOLbOPzz//HPfffz9cXFxslsvpOHDI0z2+vr5QqVRNZnYUFhYiICDATq1qv/nz5+OXX37Brl27EBISctl1R40aBQBIT08HAAQEBDRbD+trl1tHq9XKMgR4eXmhT58+SE9PR0BAAOrq6lBWVmazTuN/82upBtnZ2di+fTsee+yxy653rR8H1jZf7v96QEAAioqKbF6vr69HSUlJhxwbcvqZYg0o2dnZiI2NtelFac6oUaNQX1+Ps2fPArh26mDVs2dP+Pr62hz/jnIs7N27F6mpqb/7MwKw73HgkCHF2dkZw4YNw44dO6RlJpMJO3bsQExMjB1b1jZCCMyfPx8bNmzAzp07m3TDNSc5ORkAEBgYCACIiYnB8ePHbf6DWn+IDRgwQFqncc2s68i1ZhUVFcjIyEBgYCCGDRsGtVpt0/7U1FTk5ORI7b+WavDFF1/Az88PM2bMuOx61/pxEB4ejoCAAJv26vV6HDhwwObfvaysDIcPH5bW2blzJ0wmkxTiYmJisGfPHhgMBmmd2NhY9O3bF97e3tI6
cq6LNaCkpaVh+/bt6Nq16+9uk5ycDKVSKZ0CuRbq0Ni5c+dw8eJFm+PfEY4FwNzTOmzYMERFRf3uunY9Dlo1zPYasnbtWqHRaMTq1avFyZMnxRNPPCG8vLxsZjV0Fk899ZTw9PQUcXFxNlPGqqqqhBBCpKeni1dffVUkJiaKrKws8dNPP4mePXuKcePGSfuwTj2dPHmySE5OFlu2bBHdunVrdurpwoULxalTp8SKFStkNf32ueeeE3FxcSIrK0v89ttvYuLEicLX11cUFRUJIcxTkMPCwsTOnTtFYmKiiImJETExMdL210INhDDPVAsLCxMvvPCCzfJr9TgoLy8XSUlJIikpSQAQ7777rkhKSpJmrbzxxhvCy8tL/PTTT+LYsWPitttua3YK8tChQ8WBAwfEvn37RO/evW2mnZaVlQl/f3/x4IMPipSUFLF27Vrh5ubWZMqlk5OTePvtt8WpU6fEK6+8clWnIF+uDnV1deLWW28VISEhIjk52ebnhHWGRnx8vFi2bJlITk4WGRkZ4ptvvhHdunUTc+bM6TR1uFwNysvLxf/93/+JhIQEkZWVJbZv3y6io6NF7969RU1NjbSPzn4s/N7/ByHMU4jd3NzEypUrm2wvt+PAYUOKEEK8//77IiwsTDg7O4uRI0eK/fv327tJbQKg2T9ffPGFEEKInJwcMW7cOOHj4yM0Go2IiIgQCxcutLk+hhBCnD17VkybNk24uroKX19f8dxzzwmDwWCzzq5du8SQIUOEs7Oz6Nmzp/QecnDPPfeIwMBA4ezsLIKDg8U999wj0tPTpderq6vFn//8Z+Ht7S3c3NzE7bffLvLz82320dlrIIQQW7duFQBEamqqzfJr9TjYtWtXs8f/Qw89JIQwT0P+xz/+Ifz9/YVGoxETJkxoUpuLFy+K2bNniy5dugitViseeeQRUV5ebrPO0aNHxdixY4VGoxHBwcHijTfeaNKW77//XvTp00c4OzuLgQMHil9//fWKfe5LXa4OWVlZLf6csF5D5/Dhw2LUqFHC09NTuLi4iP79+4vXX3/d5gtcCHnX4XI1qKqqEpMnTxbdunUTarVadO/eXTz++ONNfjHt7MfC7/1/EEKIVatWCVdXV1FWVtZke7kdBwohhGhd3wsRERHRleeQY1KIiIhI/hhSiIiISJYYUoiIiEiWGFKIiIhIlhhSiIiISJYYUoiIiEiWGFKIiIhIlhhSiIiISJYYUoiIiEiWGFKIiIhIlhhSiIiISJYYUoiIiEiW/h+uOoIBq9Jt+AAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Subsample every 50th training record to thin out the curve.\n",
    "train_records = record_dict[\"train\"][::50]\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot([r[\"step\"] for r in train_records], [r[\"loss\"] for r in train_records], label=\"train\")\n",
    "ax.set(title=\"Training loss\", xlabel=\"step\", ylabel=\"loss\")\n",
    "ax.legend()  # needed for the \"train\" label to actually appear\n",
    "ax.grid()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "869bd29321bc0af7",
   "metadata": {},
   "source": [
    "## 推理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "efc609cabc2d694a",
   "metadata": {
    "ExecutionIndicator": {
     "show": true
    },
    "execution": {
     "iopub.execute_input": "2025-01-24T12:07:44.140178Z",
     "iopub.status.busy": "2025-01-24T12:07:44.139821Z",
     "iopub.status.idle": "2025-01-24T12:07:45.238632Z",
     "shell.execute_reply": "2025-01-24T12:07:45.237910Z",
     "shell.execute_reply.started": "2025-01-24T12:07:44.140155Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_348/2407769902.py:29: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  model.load_state_dict(torch.load(\"checkpoints/05_text_generation.ckpt\", map_location=device))\n",
      "  0%|          | 0/1000 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "All: I"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 1/1000 [00:00<02:43,  6.11it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " would me yet behold my lady,\n",
      "To let me see them and to make me know\n",
      "The name of Hereford that do prisone"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 11%|█         | 106/1000 [00:00<00:01, 492.57it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "r\n",
      "Defore the common people.\n",
      "\n",
      "CORIOLANUS:\n",
      "How? what?\n",
      "\n",
      "AsTept my dear banish'd; and that is Titus?\n",
      "\n",
      "MARCIUS:\n",
      "The hope th"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 22%|██▏       | 224/1000 [00:00<00:01, 765.16it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "at is the sun that seems descend;\n",
      "And therefore hence, that she may long live to you\n",
      "I would say, so good"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 33%|███▎      | 329/1000 [00:00<00:00, 866.01it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " Montagues: let him be\n",
      "called between brief; and so be Duke of Rome,\n",
      "And oppressit her that with a goodly tribune:"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 44%|████▍     | 443/1000 [00:00<00:00, 958.98it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "I know not how to come to the death of Carlia feeds,\n",
      "And with thy lips and great sorrow to deny him:\n",
      "'Tis not impossible y"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 57%|█████▋    | 566/1000 [00:00<00:00, 1046.72it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ou weep; for once\n",
      "I know not how to come and prayers.\n",
      "\n",
      "CORIOLANUS:\n",
      "The gods bless you with another.\n",
      "\n",
      "Pedant:\n",
      "B"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 68%|██████▊   | 676/1000 [00:00<00:00, 1061.40it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "y government that be brief, though not I cannot,\n",
      "Scorn against them, where's the common malice.\n",
      "\n",
      "WARWICK:\n",
      "And"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 78%|███████▊  | 785/1000 [00:00<00:00, 1053.27it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " lo, where then? do you find those things,\n",
      "Which I will hence to-morrow in this land:\n",
      "The general suit is no"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 89%|████████▉ | 893/1000 [00:00<00:00, 1029.52it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "t his restories,\n",
      "Young, closely take our leave to meet your grace,\n",
      "The ancient value of the meaning it\n",
      "Upon"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1000/1000 [00:01<00:00, 925.44it/s]\n"
     ]
    }
   ],
   "source": [
    "def generate_text(model, start_string, max_len=1000, temperature=1.0, stream=True):\n",
    "    \"\"\"Autoregressively sample `max_len` characters from a character-level model.\n",
    "\n",
    "    Args:\n",
    "        model: called as `model(input_ids, hidden=hidden)` and must return `(logits, hidden)`;\n",
    "            logits are assumed to be (batch, time, vocab), matching `logits[0, -1, :]` below.\n",
    "        start_string: non-empty prompt; every character must be a key of `char2idx`.\n",
    "        max_len: number of characters to generate (the prompt is not counted).\n",
    "        temperature: the logits are divided by it before softmax. Higher -> more random,\n",
    "            lower -> more conservative; 0 means greedy (argmax) decoding.\n",
    "        stream: if True, print the prompt and then each sampled character as it is produced.\n",
    "\n",
    "    Returns:\n",
    "        The generated text, without the prompt.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `start_string` is empty or `temperature` is negative.\n",
    "    \"\"\"\n",
    "    if not start_string:\n",
    "        raise ValueError(\"start_string must be non-empty\")\n",
    "    if temperature < 0:\n",
    "        raise ValueError(\"temperature must be >= 0\")\n",
    "    # Shape (1, len(start_string)): a batch containing a single prompt.\n",
    "    input_eval = torch.tensor([[char2idx[s] for s in start_string]], dtype=torch.int64, device=device)\n",
    "    hidden = None\n",
    "    text_generated = []  # indices of the sampled characters\n",
    "    model.eval()\n",
    "    if stream:\n",
    "        print(start_string, end=\"\")\n",
    "    # Inference only: no_grad skips building the autograd graph.\n",
    "    with torch.no_grad():\n",
    "        for _ in tqdm(range(max_len)):  # progress bar\n",
    "            # The first step consumes the whole prompt; later steps feed only the last\n",
    "            # sampled character and rely on the hidden state carried over from the previous call.\n",
    "            logits, hidden = model(input_eval, hidden=hidden)\n",
    "            # Only the distribution for the last time step is needed.\n",
    "            logits = logits[0, -1, :]\n",
    "            if temperature == 0:\n",
    "                idx = torch.argmax(logits).item()\n",
    "            else:\n",
    "                # Temperature sampling: draw one index from the softened distribution.\n",
    "                probs = F.softmax(logits / temperature, dim=-1)\n",
    "                idx = torch.multinomial(probs, 1).item()\n",
    "            # Feed the sampled index back as an int64 tensor of shape (1, 1).\n",
    "            input_eval = torch.tensor([[idx]], dtype=torch.int64, device=device)\n",
    "            text_generated.append(idx)\n",
    "            if stream:\n",
    "                print(idx2char[idx], end=\"\", flush=True)\n",
    "    return \"\".join(idx2char[i] for i in text_generated)\n",
    "\n",
    "\n",
    "# Load the weights saved during training. weights_only=True avoids unpickling arbitrary\n",
    "# objects and silences torch's FutureWarning about the old default.\n",
    "model.load_state_dict(torch.load(\"checkpoints/05_text_generation.ckpt\", map_location=device, weights_only=True))\n",
    "start_string = \"All: \"  # any prompt works, as long as its characters are in the vocabulary\n",
    "res = generate_text(model, start_string, max_len=1000, temperature=0.5, stream=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "1be041d2e307e3eb",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-24T12:07:49.711560Z",
     "iopub.status.busy": "2025-01-24T12:07:49.710961Z",
     "iopub.status.idle": "2025-01-24T12:07:49.715350Z",
     "shell.execute_reply": "2025-01-24T12:07:49.714764Z",
     "shell.execute_reply.started": "2025-01-24T12:07:49.711533Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\"I would me yet behold my lady,\\nTo let me see them and to make me know\\nThe name of Hereford that do prisoner\\nDefore the common people.\\n\\nCORIOLANUS:\\nHow? what?\\n\\nAsTept my dear banish'd; and that is Titus?\\n\\nMARCIUS:\\nThe hope that is the sun that seems descend;\\nAnd therefore hence, that she may long live to you\\nI would say, so good Montagues: let him be\\ncalled between brief; and so be Duke of Rome,\\nAnd oppressit her that with a goodly tribune:\\nI know not how to come to the death of Carlia feeds,\\nAnd with thy lips and great sorrow to deny him:\\n'Tis not impossible you weep; for once\\nI know not how to come and prayers.\\n\\nCORIOLANUS:\\nThe gods bless you with another.\\n\\nPedant:\\nBy government that be brief, though not I cannot,\\nScorn against them, where's the common malice.\\n\\nWARWICK:\\nAnd lo, where then? do you find those things,\\nWhich I will hence to-morrow in this land:\\nThe general suit is not his restories,\\nYoung, closely take our leave to meet your grace,\\nThe ancient value of the meaning it\\nUpon\""
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the complete generated text (prompt excluded) as a single string.\n",
    "res"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
